library(broom)
library(MASS)
library(tidyverse)
Loading tidyverse: ggplot2
Loading tidyverse: tibble
Loading tidyverse: tidyr
Loading tidyverse: readr
Loading tidyverse: purrr
Loading tidyverse: dplyr
Conflicts with tidy packages --------------------------------------------------------
filter(): dplyr, stats
lag(): dplyr, stats
select(): dplyr, MASS
library(ggthemes)
library(plotly)
Attaching package: ‘plotly’
The following object is masked from ‘package:ggplot2’:
last_plot
The following object is masked from ‘package:MASS’:
select
The following object is masked from ‘package:stats’:
filter
The following object is masked from ‘package:graphics’:
layout
library(modelr)
Attaching package: ‘modelr’
The following object is masked from ‘package:broom’:
bootstrap
library(DT)
library(ggrepel)
library(rpart)
library(ggplot2)
# Base every plot in this notebook on the FiveThirtyEight look, but
# restore axis titles (the theme blanks them) and pin the legend to the
# right-hand side, stacked vertically.
notebook_theme <- theme_fivethirtyeight() +
  theme(
    axis.title = element_text(),
    legend.position = "right",
    legend.direction = "vertical"
  )
theme_set(notebook_theme)
# Read the car data set from disk and peek at the first few rows.
# NOTE(review): this name shadows the built-in `datasets::cars` data
# frame — consider renaming if base `cars` is ever needed later.
cars <- read.csv("cars.csv")
head(cars)
# datatable(cars, options = list())
Let's make a simple tree model
# Fit a classification tree predicting body `type` from mileage (kmpl),
# power (bhp) and price; print the full split-by-split summary.
tree <- rpart(type ~ kmpl + bhp + price, data = cars, method = "class")
summary(tree)
Call:
rpart(formula = type ~ kmpl + bhp + price, data = cars, method = "class")
n= 42
CP nsplit rel error xerror xstd
1 0.72222222 0 1.0000000 1.0000000 0.1781742
2 0.05555556 1 0.2777778 0.5555556 0.1533479
3 0.01000000 2 0.2222222 0.6111111 0.1582997
Variable importance
price bhp kmpl
45 31 24
Node number 1: 42 observations, complexity param=0.7222222
predicted class=Hatchback expected loss=0.4285714 P(node) =1
class counts: 24 18
probabilities: 0.571 0.429
left son=2 (19 obs) right son=3 (23 obs)
Primary splits:
price < 474.5 to the left, improve=12.74534, (0 missing)
bhp < 88 to the left, improve=10.97143, (0 missing)
kmpl < 17.65 to the right, improve= 5.43254, (0 missing)
Surrogate splits:
kmpl < 17.85 to the right, agree=0.810, adj=0.579, (0 split)
bhp < 74.5 to the left, agree=0.786, adj=0.526, (0 split)
Node number 2: 19 observations
predicted class=Hatchback expected loss=0 P(node) =0.452381
class counts: 19 0
probabilities: 1.000 0.000
Node number 3: 23 observations, complexity param=0.05555556
predicted class=Sedan expected loss=0.2173913 P(node) =0.547619
class counts: 5 18
probabilities: 0.217 0.783
left son=6 (9 obs) right son=7 (14 obs)
Primary splits:
bhp < 84 to the left, improve=3.3816430, (0 missing)
price < 543 to the left, improve=2.3715420, (0 missing)
kmpl < 16.25 to the right, improve=0.2094203, (0 missing)
Surrogate splits:
price < 599 to the left, agree=0.826, adj=0.556, (0 split)
kmpl < 17.65 to the right, agree=0.652, adj=0.111, (0 split)
Node number 6: 9 observations
predicted class=Hatchback expected loss=0.4444444 P(node) =0.2142857
class counts: 5 4
probabilities: 0.556 0.444
Node number 7: 14 observations
predicted class=Sedan expected loss=0 P(node) =0.3333333
class counts: 0 14
probabilities: 0.000 1.000
# Build a 25 x 25 x 25 grid spanning the observed range of each
# predictor, score it with the fitted tree, and attach the predicted
# probability of the "Hatchback" class as column `pred`.
grid <- cars %>%
  data_grid(
    kmpl = seq_range(kmpl, 25),
    bhp = seq_range(bhp, 25),
    price = seq_range(price, 25)
  )
# Rename by name inside select() rather than via colnames()[4] — the
# positional rename silently breaks if the grid's column set changes.
pred <- predict(tree, grid, type = "prob") %>%
  as.data.frame() %>%
  dplyr::select(pred = Hatchback)
tree_grid <- bind_cols(grid, pred)
head(tree_grid)
NA
# Interactive 3-D scatter of the prediction grid, coloured by the
# tree's predicted Hatchback probability.
p <- tree_grid %>%
  plot_ly(x = ~kmpl, y = ~bhp, z = ~price, color = ~pred) %>%
  add_markers() %>%
  layout(
    scene = list(
      xaxis = list(title = 'kmpl'),
      yaxis = list(title = 'bhp'),
      zaxis = list(title = 'price')
    )
  )
p